{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Attribution Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import operator\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "from time import time\n",
    "import lucid.modelzoo.vision_models as models\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to get general information of inception_v1 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_layers(graph_nodes):\n",
    "    '''\n",
    "    Get all layers\n",
    "    * input\n",
    "        - graph_nodes: tensorflow graph nodes\n",
    "    * output\n",
    "        - layers: list of all layers\n",
    "    '''\n",
    "    layers = []\n",
    "    for n in graph_nodes:\n",
    "        node_name = n.name\n",
    "        if node_name[-2:] == '_w':\n",
    "            layer = node_name.split('_')[0]\n",
    "            if layer not in layers:\n",
    "                layers.append(layer)\n",
    "    return layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_channel_sizes(layer, weight_nodes):\n",
    "    '''\n",
    "    Get channel sizes\n",
    "    * input\n",
    "        - layer: the name of layer\n",
    "        - weight_nodes: tensorflow nodes for all filters\n",
    "    * output\n",
    "        - channel_sizes: list of channel size for all pre-concatenated blocks\n",
    "    '''\n",
    "    \n",
    "    channel_sizes = [get_shape_of_node(n)[0] for n in weight_nodes if layer in n.name and '_b' == n.name[-2:] and 'bottleneck' not in n.name]\n",
    "    return channel_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_shape_of_node(n):\n",
    "    '''\n",
    "    Get the shape of the tensorflow node\n",
    "    * input\n",
    "        - n: tensorflow node\n",
    "    * output\n",
    "        - tensor_shape: shape of n\n",
    "    '''\n",
    "    dims = n.attr['value'].tensor.tensor_shape.dim\n",
    "    tensor_shape = [d.size for d in dims]\n",
    "    return tensor_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prev_layer(layer, mixed_layers):\n",
    "    layer_idx = mixed_layers.index(layer)\n",
    "    return mixed_layers[layer_idx - 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to extract influences from I-matrices for a class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_class_I_matrices(Is, all_layers, start_layer, end_layer, pred_class, verbose=True):\n",
    "    '''\n",
    "    Extract influences for a class from I-matrices\n",
    "    * input\n",
    "        - Is: I-matrices for all class\n",
    "        - all_layers: list of all layers\n",
    "        - start_layer: start layer (towards output)\n",
    "        - end_layer: end layer (towards input)\n",
    "        - pred_class: predicted class\n",
    "    * output\n",
    "        - Is_class: I-matrices for a class. a dictionary, where\n",
    "            - key: layer name (e.g. 'mixed4d', 'mixed4d_1')\n",
    "            - val: influences for given layer(key), class(argument of the function)\n",
    "    '''\n",
    "\n",
    "    # Get layers starting from the given layer to the input layer\n",
    "    start_idx, end_idx = all_layers.index(start_layer), all_layers.index(end_layer)\n",
    "    target_layers = all_layers[start_idx: end_idx - 1: -1]\n",
    "    \n",
    "    Is_class = {}\n",
    "    \n",
    "    for layer in target_layers:\n",
    "        if verbose:\n",
    "            print('\\n({}) loading {}'.format(pred_class, layer), end='')\n",
    "        Is_class[layer] = Is[layer][pred_class]\n",
    "        for branch in [1, 2]:\n",
    "            inner_layer = '{}_{}'.format(layer, branch)\n",
    "            if verbose:\n",
    "                print(',', inner_layer, end='')\n",
    "            Is_class[inner_layer] = Is[inner_layer][pred_class]\n",
    "    if verbose:\n",
    "        print('\\n')\n",
    "    \n",
    "    return Is_class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_I_matrices(all_layers, start_layer, end_layer, I_mat_dirpath, verbose=True):\n",
    "    '''\n",
    "    Load I-matrices for all layers\n",
    "    * input\n",
    "        - all_layers: list of all layers\n",
    "        - start_layer: start layer (towards output)\n",
    "        - end_layer: end layer (towards input)\n",
    "        - I_mat_dirpath: directory path of I-matrices\n",
    "    * output\n",
    "        - Is: I-matrices for all class\n",
    "    '''\n",
    "    \n",
    "    # Get layers starting from the given layer to the input layer\n",
    "    start_idx, end_idx = all_layers.index(start_layer), all_layers.index(end_layer)\n",
    "    target_layers = all_layers[start_idx: end_idx - 1: -1]\n",
    "\n",
    "    # Load I matrices\n",
    "    Is = {}\n",
    "    for layer in target_layers:\n",
    "        if verbose:\n",
    "            print('\\n(all) loading {}'.format(layer), end='')\n",
    "        Is[layer] = load_inf_matrix(I_mat_dirpath, layer)\n",
    "        for branch in [1, 2]:\n",
    "            inner_layer = '{}_{}'.format(layer, branch)\n",
    "            if verbose:\n",
    "                print(',', inner_layer, end='')\n",
    "            Is[inner_layer] = load_inf_matrix(I_mat_dirpath, inner_layer)\n",
    "    if verbose:\n",
    "        print('\\n')    \n",
    "    return Is"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_inf_matrix(I_mat_dirpath, layer):\n",
    "    '''\n",
    "    Load I matrix for a layer\n",
    "    * input\n",
    "        - mat_dirpath: directory path of I-matrices\n",
    "        - layer: layer name\n",
    "    * output\n",
    "        - I_mat: I-matrix of the given layer\n",
    "    '''\n",
    "    if I_mat_dirpath[-1] == '/':\n",
    "        filepath = I_mat_dirpath + 'I_' + layer + '.json'\n",
    "    else:\n",
    "        filepath = I_mat_dirpath + '/I_' + layer + '.json'\n",
    "        \n",
    "    with open(filepath) as f:\n",
    "        I_mat = json.load(f)\n",
    "    \n",
    "    return I_mat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to extract A matrix information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_A(A_mat_dirpath, layer, A_prob_threshold):\n",
    "    A = np.loadtxt(A_mat_dirpath + 'A-' + str(A_prob_threshold) + '-' + layer + '.csv', delimiter=',', dtype=int)\n",
    "    return A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_As(A_mat_dirpath, mixed_layers, A_prob_threshold):\n",
    "    As = {}\n",
    "    for layer in mixed_layers:\n",
    "        A = read_A(A_mat_dirpath, layer, A_prob_threshold)\n",
    "        As[layer] = A\n",
    "    return As"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to query the influence values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_branch(layer, channel, layer_channels):\n",
    "    '''\n",
    "    Get branch of the channel in the layer\n",
    "    * input\n",
    "        - layer: the name of layer\n",
    "        - channel: channel in the layer\n",
    "        - layer_channels: fragment sizes of the layer\n",
    "    * output\n",
    "        - branch: branch of the channel\n",
    "    '''\n",
    "    \n",
    "    channels = layer_channels[:]\n",
    "    for i in range(len(channels) - 1):\n",
    "        channels[i + 1] += channels[i]\n",
    "        \n",
    "    branch = np.searchsorted(channels, channel, side='right')\n",
    "    \n",
    "    return branch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_num_of_prevs_for_a_channel(layer, Is_class):\n",
    "    '''\n",
    "    Get the average number of previous channels connected to a channel in the given layer\n",
    "    * input\n",
    "        - layer: layer\n",
    "        - Is_class: I-matrices for a class\n",
    "    * output\n",
    "        - num_avg: the average number of connections for a channel in the given layer\n",
    "    '''\n",
    "    \n",
    "    num_of_channel_edges = []\n",
    "    \n",
    "    for channel, prev_inf_dict in enumerate(Is_class[layer]):\n",
    "        # Get branch\n",
    "        branch = get_branch(layer, channel, layer_fragment_sizes[layer])\n",
    "        \n",
    "        if branch in [0, 3]:\n",
    "            num_of_channel_edges.append(len(prev_inf_dict))\n",
    "            \n",
    "    num_avg = int(np.average(num_of_channel_edges))\n",
    "    \n",
    "    return num_avg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to generate the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_node_name(layer, channel):\n",
    "    return layer + '-' + str(channel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_full_graph(Is_class, G, mixed_layers, outlier_nodes):\n",
    "    \n",
    "    # Add edges into G from Is_class\n",
    "    for layer_idx, layer in enumerate(mixed_layers[::-1][:-1]):\n",
    "        # Get previous layer\n",
    "        prev_layer = mixed_layers[::-1][layer_idx + 1]\n",
    "        \n",
    "        # Get the average number of edges for a channel\n",
    "        avg_num_edges = avg_num_of_prevs_for_a_channel(layer, Is_class)\n",
    "        \n",
    "        # For all channels in layer\n",
    "        for channel, prev_inf_dict in enumerate(Is_class[layer]):\n",
    "            # Get source node\n",
    "            src = get_node_name(layer, channel)\n",
    "            \n",
    "            # Get branch\n",
    "            branch = get_branch(layer, channel, layer_fragment_sizes[layer])\n",
    "            \n",
    "            # If the channel is connected to a branch\n",
    "            if branch in [1, 2]:\n",
    "                # Get possible edge weights for the channel\n",
    "                channel_edges = {}\n",
    "                for prev_channel in prev_inf_dict:\n",
    "                    prev_inf = prev_inf_dict[prev_channel]\n",
    "                    \n",
    "                    # Extract influence information for prev_channel\n",
    "                    branch_layer = '{}_{}'.format(layer, branch)\n",
    "                    prev_prev_inf_dict = Is_class[branch_layer][int(prev_channel)]\n",
    "                    \n",
    "                    for prev_prev_channel in prev_prev_inf_dict:\n",
    "                        prev_prev_inf = prev_prev_inf_dict[prev_prev_channel]\n",
    "                        if prev_prev_channel not in channel_edges:\n",
    "                            channel_edges[prev_prev_channel] = []\n",
    "                        channel_edges[prev_prev_channel].append(min(prev_inf, prev_prev_inf))\n",
    "                        \n",
    "                # Get only one weight for each channel and prev_prev channel\n",
    "                for prev_prev_channel in channel_edges:\n",
    "                    channel_edges[prev_prev_channel] = max(channel_edges[prev_prev_channel])\n",
    "                \n",
    "                # Get top (avg_num_edges) prev_prev_channels based on the edge weight\n",
    "                top_prev_prevs_weights = sorted(channel_edges.items(), key=operator.itemgetter(1), reverse=True)\n",
    "                top_prev_prevs_weights = top_prev_prevs_weights[:avg_num_edges]\n",
    "                \n",
    "                # Add edges from channel and top_prev_prev_channel\n",
    "                for prev_prev_channel, weight in top_prev_prevs_weights:\n",
    "                    tgt = get_node_name(prev_layer, prev_prev_channel)\n",
    "                    G.add_edge(src, tgt, weight=weight)\n",
    "            \n",
    "            # If the channel is directly connected to the previous layer\n",
    "            elif branch in [0, 3]:\n",
    "                for prev_channel in prev_inf_dict:\n",
    "                    # Add edge of (src, tgt-prev)\n",
    "                    prev_inf = prev_inf_dict[prev_channel]\n",
    "                    tgt = get_node_name(prev_layer, prev_channel)\n",
    "                    G.add_edge(src, tgt, weight=prev_inf)\n",
    "                    \n",
    "    # remove outlier nodes and edges from full graph\n",
    "    for outlier in outlier_nodes:\n",
    "        G.remove_node(outlier)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_dag(mixed_layers):\n",
    "    dag = {}\n",
    "    for layer in mixed_layers[::-1]:\n",
    "        dag[layer] = []\n",
    "    return dag"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions for pagerank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_personalization_dict(G, As, mixed_layers, pred_class, outlier_nodes):\n",
    "    '''\n",
    "    Get personalization dictionary\n",
    "    * input\n",
    "        - G: graph\n",
    "        - As: A matrices\n",
    "        - mixed_layers: layers starting with 'mixed'\n",
    "    '''\n",
    "    \n",
    "    personalization = {node: 1 for node in list(G.nodes)}\n",
    "\n",
    "    for layer in mixed_layers[::-1]:\n",
    "        A = As[layer][pred_class]\n",
    "        max_a = -100\n",
    "        for channel, val in enumerate(A):\n",
    "            node = get_node_name(layer, channel)\n",
    "            if node not in outlier_nodes:\n",
    "                max_a = max(max_a, val)\n",
    "                \n",
    "        for channel in range(A.shape[-1]):\n",
    "            node = layer + '-' + str(channel)\n",
    "            if node in personalization:\n",
    "                personalization[get_node_name(layer, channel)] = A[channel] / max_a\n",
    "    \n",
    "    return personalization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions for thresholding nodes, edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prob_mass(prob_mass_threshold, reverse_sorted_vals):\n",
    "\n",
    "    prob_mass = 0\n",
    "    threshold_cnt = 0\n",
    "    sum_val = np.sum(reverse_sorted_vals)\n",
    "    \n",
    "    while prob_mass < prob_mass_threshold:\n",
    "        prob_mass += reverse_sorted_vals[threshold_cnt] / sum_val\n",
    "        threshold_cnt += 1\n",
    "\n",
    "    threshold_val = reverse_sorted_vals[threshold_cnt]\n",
    "    return threshold_cnt, threshold_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_threshold(mixed_layers, pagerank, prob_mass_thresholds, unified_threshold=True):\n",
    "    '''\n",
    "    Get threshold dictionary\n",
    "    * input\n",
    "        - mixed_layers: layers starting with mixed\n",
    "        - pagerank: pagerank dictionary\n",
    "        - prob_mass_thresholds: probability mass thresholds for each layer\n",
    "        - unified_threshold: whether the same threshold is used for all layers\n",
    "    * output\n",
    "        - thresholds: threshold dictionary, whose \n",
    "            - key: layer\n",
    "            - val: threshold criteria value\n",
    "    '''\n",
    "    \n",
    "    thresholds = {}\n",
    "    prob_mass_threshold = list(prob_mass_thresholds.values())[0]\n",
    "    \n",
    "    if unified_threshold:\n",
    "        # Get threshold value for all layer\n",
    "        threshold_val = get_threshold_val(mixed_layers, pagerank, prob_mass_threshold=prob_mass_threshold)\n",
    "    \n",
    "        # Get thresholds\n",
    "        for layer in mixed_layers[::-1]:\n",
    "            thresholds[layer] = threshold_val\n",
    "    \n",
    "    else:\n",
    "        for layer in mixed_layers[::-1]:\n",
    "            pagerank_layer = {node: pagerank[node] for node in pagerank if node.split('-')[0] == layer}\n",
    "            threshold_val_layer = get_threshold_val(mixed_layers, pagerank_layer, prob_mass_threshold=prob_mass_thresholds[layer])\n",
    "            thresholds[layer] = threshold_val_layer\n",
    "\n",
    "    return thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_threshold_val(mixed_layers, pagerank, prob_mass_threshold=0.12):\n",
    "    pagerank_values = list(pagerank.values())\n",
    "    sorted_pagerank_vals = sorted(pagerank_values, reverse=True)\n",
    "    threshold_cnt, threshold_val = get_prob_mass(prob_mass_threshold, sorted_pagerank_vals)\n",
    "    \n",
    "    return threshold_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_thresholded_nodes(pagerank, thresholds, outlier_nodes):\n",
    "    thresholded_nodes = {}\n",
    "    for node in pagerank:\n",
    "        if node in outlier_nodes:\n",
    "            continue\n",
    "            \n",
    "        layer, channel = node.split('-')\n",
    "        if pagerank[node] > thresholds[layer]:\n",
    "            thresholded_nodes[node] = pagerank[node]\n",
    "        \n",
    "    return thresholded_nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_thresholded_edges(mixed_layers, G, thresholded_nodes):\n",
    "    '''\n",
    "    Get thresholded edges\n",
    "    * input\n",
    "        - mixed_layers: all mixed layers\n",
    "        - G: graph\n",
    "        - thresholded_nodes: nodes whose pagerank value is higher than threshold\n",
    "    * output\n",
    "        - thresholded_edges: edges connected by both thresholded nodes\n",
    "    '''\n",
    "    \n",
    "    # Initialize thresholded_edges\n",
    "    thresholded_edges = {}\n",
    "    edge_checker = set()\n",
    "    for layer in mixed_layers[::-1]:\n",
    "        thresholded_edges[layer] = {}\n",
    "\n",
    "    for node in thresholded_nodes:\n",
    "        for edge in G.edges(node):\n",
    "            node1, node2 = edge\n",
    "            if (node1 not in thresholded_nodes) or (node2 not in thresholded_nodes):\n",
    "                continue\n",
    "\n",
    "            layer, channel, prev_layer, prev_channel = parse_edge(edge, mixed_layers)\n",
    "            if channel not in thresholded_edges[layer]:\n",
    "                thresholded_edges[layer][channel] = []\n",
    "\n",
    "            if (node1, node2) in edge_checker:\n",
    "                continue\n",
    "            elif (node2, node1) in edge_checker:\n",
    "                continue\n",
    "            edge_checker.add((node1, node2))\n",
    "\n",
    "            thresholded_edges[layer][channel].append({\n",
    "                'prev_channel': int(prev_channel),\n",
    "                'inf': G.get_edge_data(*edge)['weight']\n",
    "            })\n",
    "\n",
    "    return thresholded_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prev_nodes_with_valid_edges(thresholded_edges, thresholded_nodes, prev_layer):\n",
    "    if prev_layer == 'mixed3a':\n",
    "        mixed3a_nodes = [x.split('-')[1] for x in thresholded_nodes if 'mixed3a' in x]\n",
    "        return mixed3a_nodes\n",
    "    else:\n",
    "        return list(thresholded_edges[prev_layer].keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to generate DAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_edge(edge, mixed_layers):\n",
    "    n1_layer, n1_channel = edge[0].split('-')\n",
    "    n2_layer, n2_channel = edge[1].split('-')\n",
    "    \n",
    "    n1_idx, n2_idx = mixed_layers.index(n1_layer), mixed_layers.index(n2_layer)\n",
    "    \n",
    "    # If n1 is current layer, n2 is previous layer\n",
    "    if n1_idx > n2_idx:\n",
    "        layer, channel = n1_layer, n1_channel\n",
    "        prev_layer, prev_channel = n2_layer, n2_channel\n",
    "        \n",
    "    # If n1 is previous layer, n1 is current layer\n",
    "    else:\n",
    "        layer, channel = n2_layer, n2_channel\n",
    "        prev_layer, prev_channel = n1_layer, n1_channel\n",
    "    \n",
    "    return layer, channel, prev_layer, prev_channel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_dag(mixed_layers, thresholded_nodes, thresholded_edges, pagerank, As, Is, pred_class):\n",
    "    \n",
    "    # Initialize dag, check_channel, layer_validity\n",
    "    dag = {}\n",
    "    check_channel = {}\n",
    "    layer_validity = {}\n",
    "    for layer in mixed_layers[::-1]:\n",
    "        dag[layer] = []\n",
    "        check_channel[layer] = set()\n",
    "        if layer != 'mixed3a':\n",
    "            layer_validity[layer] = False\n",
    "    \n",
    "    # Mixed3a\n",
    "    layer = 'mixed3a'\n",
    "    for channel, cnt in enumerate(As['mixed3a'][pred_class]):\n",
    "        node_name = get_node_name(layer, channel)\n",
    "        if node_name in thresholded_nodes:\n",
    "            dag[layer].append({\n",
    "                'channel': int(channel),\n",
    "                'count': int(cnt),\n",
    "                'layer': layer,\n",
    "                'pagerank': pagerank[node_name],\n",
    "                'prev_channels': [],\n",
    "                'attr_channels': []\n",
    "            })\n",
    "            \n",
    "    # Other layers\n",
    "    for node in thresholded_nodes:\n",
    "        for edge in G.edges(node):\n",
    "            \n",
    "            # Parse the edge\n",
    "            node1, node2 = edge\n",
    "            layer, channel, prev_layer, prev_channel = parse_edge(edge, mixed_layers)\n",
    "            \n",
    "            curr_node = get_node_name(layer, channel)\n",
    "            prev_node = get_node_name(prev_layer, prev_channel)\n",
    "            \n",
    "            # Ignore unnecessary cases\n",
    "            if curr_node not in thresholded_nodes:\n",
    "                continue\n",
    "            if channel in check_channel[layer]:\n",
    "                continue\n",
    "            check_channel[layer].add(channel)\n",
    "            \n",
    "            # Read A matrices\n",
    "            A = As[layer]\n",
    "\n",
    "            # Get attributed previous channels\n",
    "            attr_channels_dict = Is[layer][pred_class][int(channel)]\n",
    "            attr_channels_dict = sorted(attr_channels_dict.items(), key=lambda x:x[1], reverse=True)[0:3]\n",
    "            attr_channels = [{'prev_channel': prev[0], 'inf': prev[1]} for prev in attr_channels_dict]\n",
    "\n",
    "            # If the channel is connected to thresholded nodes in previous layer\n",
    "            if channel in thresholded_edges[layer]:\n",
    "                # Mark the layer is valid\n",
    "                layer_validity[layer] = True\n",
    "            \n",
    "                # Get previous channels    \n",
    "                valid_prev_nodes = get_prev_nodes_with_valid_edges(thresholded_edges, thresholded_nodes, prev_layer)\n",
    "                filtered_prev_channels = list(filter(lambda x: str(x['prev_channel']) in valid_prev_nodes, thresholded_edges[layer][channel]))\n",
    "\n",
    "                # Add node into dag\n",
    "                dag[layer].append({\n",
    "                    'channel': int(channel),\n",
    "                    'count': int(A[pred_class][int(channel)]),\n",
    "                    'layer': layer,\n",
    "                    'pagerank': pagerank[curr_node],\n",
    "                    'prev_channels': filtered_prev_channels,\n",
    "                    'attr_channels': attr_channels\n",
    "                })\n",
    "            \n",
    "            # If the channel is not connected to thresholded nodes in previous layer\n",
    "            else:\n",
    "                # Add node into dag\n",
    "                dag[layer].append({\n",
    "                    'channel': int(channel),\n",
    "                    'count': int(A[pred_class][int(channel)]),\n",
    "                    'layer': layer,\n",
    "                    'pagerank': pagerank[curr_node],\n",
    "                    'prev_channels': [],\n",
    "                    'attr_channels': attr_channels\n",
    "                })\n",
    "\n",
    "    return dag, layer_validity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get inception_v1 model infromation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_dirpath = '/Users/haekyu/data/summit/'\n",
    "# data_dirpath = '/home/fred/code/summit-notebooks/data/'\n",
    "data_dirpath = '/Users/fredhohman/Github/summit-notebooks/data/'\n",
    "imgnet_dirpath = data_dirpath\n",
    "I_mat_dirpath = data_dirpath + 'I-matrices/'\n",
    "A_mat_dirpath = data_dirpath + 'A-matrices/'\n",
    "dag_dirpath = data_dirpath + 'ag/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "googlenet = models.InceptionV1()\n",
    "googlenet.load_graphdef()\n",
    "nodes = googlenet.graph_def.node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_layers = get_layers(nodes)\n",
    "mixed_layers = [layer for layer in all_layers if 'mixed' in layer]\n",
    "layer_fragment_sizes = {layer: get_channel_sizes(layer, nodes) for layer in mixed_layers}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_prob_threshold = 0.03\n",
    "with open(imgnet_dirpath + 'imagenet-' + str(A_prob_threshold) + '.json') as f:\n",
    "    imgnet = json.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run for all classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_layer = 'mixed5b'\n",
    "end_layer = 'mixed3a'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Is = load_I_matrices(all_layers, start_layer, end_layer, I_mat_dirpath, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outlier_nodes = ['mixed3a-67', 'mixed3a-190', 'mixed3b-390', 'mixed3b-399', 'mixed3b-412']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_class = 1000\n",
    "prob_mass_dict = {}\n",
    "unified_threshold = True\n",
    "plot_pagerank_values = False\n",
    "test_class_for_debugging = True\n",
    "\n",
    "# start at 7.5%, increase by 0.5% if no connecetions/channels in a layer\n",
    "prob_mass_initial = 0.075\n",
    "prob_mass_increase = 0.005\n",
    "\n",
    "for pred_class in range(num_class):\n",
    "    prob_mass_thresholds = {layer: prob_mass_initial for layer in mixed_layers}\n",
    "\n",
    "    if test_class_for_debugging:\n",
    "        if pred_class not in [55,270,294]: \n",
    "            continue\n",
    "        \n",
    "    tic = time()\n",
    "    \n",
    "    # Extract influence information for the pred_class\n",
    "    Is_class = extract_class_I_matrices(Is, all_layers, start_layer, end_layer, pred_class, verbose=False)\n",
    "    \n",
    "    # Initialize an undirected graph\n",
    "    G = nx.Graph()\n",
    "    \n",
    "    # Generate full graph\n",
    "    gen_full_graph(Is_class, G, mixed_layers, outlier_nodes)\n",
    "    \n",
    "    # Read A matrices\n",
    "    As = read_As(A_mat_dirpath, mixed_layers, A_prob_threshold)\n",
    "    \n",
    "    # Personalized pagerank to filter nodes\n",
    "    personalization = get_personalization_dict(G, As, mixed_layers, pred_class, outlier_nodes)\n",
    "    pagerank = nx.pagerank(G, personalization=personalization, weight='weight', alpha=0.85)\n",
    "    \n",
    "    if plot_pagerank_values:\n",
    "\n",
    "        for layer in mixed_layers:\n",
    "            print(layer)\n",
    "\n",
    "            pagerank_values = []\n",
    "            for node in pagerank:\n",
    "                if node.split('-')[0] == layer:\n",
    "                    pagerank_values.append(pagerank[node])\n",
    "\n",
    "            print(len(pagerank_values))\n",
    "            plt.figure(figsize=(12,3))\n",
    "            plt.hist(pagerank_values, bins=100)\n",
    "            plt.show()\n",
    "    \n",
    "    need_to_relax = True\n",
    "\n",
    "    while need_to_relax:\n",
    "        \n",
    "        # Thresolding\n",
    "        thresholds = get_threshold(mixed_layers, pagerank, prob_mass_thresholds=prob_mass_thresholds, unified_threshold=unified_threshold)\n",
    "        thresholded_nodes = get_thresholded_nodes(pagerank, thresholds, outlier_nodes=outlier_nodes)\n",
    "        thresholded_edges = get_thresholded_edges(mixed_layers, G, thresholded_nodes)\n",
    "\n",
    "        # Generate dag in json format\n",
    "        dag, layer_validity = gen_dag(mixed_layers, thresholded_nodes, thresholded_edges, pagerank, As, Is, pred_class)\n",
    "        \n",
    "        need_to_relax = False in layer_validity.values()\n",
    "        if need_to_relax:\n",
    "            if unified_threshold:\n",
    "                for layer in mixed_layers:\n",
    "                    prob_mass_thresholds[layer] += prob_mass_increase\n",
    "            else:\n",
    "                for layer_idx, layer in enumerate(mixed_layers[1:]):\n",
    "                    if not layer_validity[layer]:\n",
    "                        prev_layer = mixed_layers[layer_idx]\n",
    "                        prob_mass_thresholds[prev_layer] += prob_mass_increase\n",
    "                        prob_mass_thresholds[layer] += prob_mass_increase\n",
    "\n",
    "    # Save prob_mass_threshold\n",
    "    prob_mass_dict[pred_class] = prob_mass_thresholds\n",
    "    \n",
    "    # Save the graph into a file\n",
    "    filename = dag_dirpath + 'ag-{}.json'.format(pred_class)\n",
    "    with open(filename, 'w') as f:\n",
    "        json.dump(dag, f, indent=2)\n",
    "\n",
    "    toc = time()\n",
    "\n",
    "    print('class: %s, time: %.2lf sec' % (pred_class, toc - tic))\n",
    "#     print(nx.info(G))\n",
    "    \n",
    "filename = dag_dirpath + 'pagerank/' + 'prob-mass-threshold-{}.json'.format('unified' if unified_threshold else 'separate')\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(prob_mass_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "pg_dict = defaultdict(lambda: {'min': 100, 'max': -100})\n",
    "for node in pagerank:\n",
    "    if node in outlier_nodes:\n",
    "        continue\n",
    "    layer, channel = node.split('-')\n",
    "    val = pagerank[node]\n",
    "    pg_dict[layer] = {'min': min(pg_dict[layer]['min'], val), 'max': max(pg_dict[layer]['max'], val)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer in pg_dict:\n",
    "    a, A = pg_dict[layer]['min'], pg_dict[layer]['max']\n",
    "    print('%s, min:%.4lf, max:%.4lf' % (layer, a, A))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
